{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/Bonus_UnsupervisedLearning/Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neuromatch Academy: Week 3, Day 5, Tutorial 1\n",
    "# Deep Learning 2: Autoencoders\n",
    "\n",
    "__Content creators:__ Marco Brigham and the [CCNSS](https://www.ccnss.org/) team (2014-2018)\n",
    "\n",
    "__Content reviewers:__ Itzel Olivos, Karen Schroeder, Karolina Stosio, Kshitij Dwivedi, Spiros Chavlis, Michael Waskom"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Tutorial Objectives\n",
    "\n",
    "## Internal representations and autoencoders\n",
    "How can simple algorithms capture relevant aspects of data and build robust models of the world?\n",
    "\n",
    "Autoencoders are a family of artificial neural networks (ANNs) that learn internal representations through auxiliary tasks, i.e., *learning by doing*.\n",
    "\n",
    "The primary task is to reconstruct output images based on a compressed representation of the inputs. This task teaches the network which details to throw away while still producing images that are similar to the inputs.\n",
    "\n",
    "&nbsp; \n",
    "\n",
    "A fictitious *MNIST cognitive task* bundles more elaborate tasks, such as removing noise from images, guessing occluded parts, and recovering the original image orientation. We use handwritten digits from the MNIST dataset because similar images and reconstruction errors are easier to spot by eye than in other types of data, such as spiking time series.\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "![MNIST cognitive task](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/mnist_task.png)\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "A key appeal of autoencoders is that these internal representations can be visualized. The bottleneck layer enforces data compression by having fewer units than the input and output layers. Further limiting this layer to two or three units lets us see how the autoencoder organizes the data internally in a two- or three-dimensional *latent space*.\n",
    "\n",
    "&nbsp; \n",
    "\n",
    "Our roadmap is the following: learn about typical elements of autoencoder architecture in Tutorial 1 (this tutorial), learn how to improve their performance in Tutorial 2, and use them to solve the MNIST cognitive task in Tutorial 3.\n",
    "\n",
    "&nbsp; \n",
    "\n",
    "In this tutorial, you will:\n",
    "- Get acquainted with latent space visualizations and apply them to *Principal Component Analysis (PCA)* and *Non-negative Matrix Factorization (NMF)*\n",
    "- Build and train a single hidden layer ANN autoencoder\n",
    "- Inspect the representational power of autoencoders with latent spaces of different dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T00:59:35.468604Z",
     "iopub.status.busy": "2021-05-25T00:59:35.467886Z",
     "iopub.status.idle": "2021-05-25T00:59:35.526114Z",
     "shell.execute_reply": "2021-05-25T00:59:35.525563Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Video 1: Intro\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"FBTHsDCrXcU\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Setup\n",
    "Please execute the cell(s) below to initialize the notebook environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T00:59:35.530235Z",
     "iopub.status.busy": "2021-05-25T00:59:35.529700Z",
     "iopub.status.idle": "2021-05-25T00:59:38.570471Z",
     "shell.execute_reply": "2021-05-25T00:59:38.569413Z"
    }
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "from sklearn import decomposition\n",
    "from sklearn.datasets import fetch_openml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T00:59:38.584172Z",
     "iopub.status.busy": "2021-05-25T00:59:38.580720Z",
     "iopub.status.idle": "2021-05-25T00:59:38.618024Z",
     "shell.execute_reply": "2021-05-25T00:59:38.616985Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Figure settings\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")\n",
    "fig_w, fig_h = plt.rcParams['figure.figsize']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T00:59:38.653266Z",
     "iopub.status.busy": "2021-05-25T00:59:38.636576Z",
     "iopub.status.idle": "2021-05-25T00:59:38.669760Z",
     "shell.execute_reply": "2021-05-25T00:59:38.669051Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Helper functions\n",
    "\n",
    "\n",
    "def downloadMNIST():\n",
    "  \"\"\"\n",
    "  Download MNIST dataset and transform it to torch.Tensor\n",
    "\n",
    "  Args:\n",
    "    None\n",
    "\n",
    "  Returns:\n",
    "    x_train : training images (torch.Tensor) (60000, 28, 28)\n",
    "    x_test  : test images (torch.Tensor) (10000, 28, 28)\n",
    "    y_train : training labels (torch.Tensor) (60000, )\n",
    "    y_test  : test labels (torch.Tensor) (10000, )\n",
    "  \"\"\"\n",
    "  X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame = False)\n",
    "  # Split into train and test sets (first 60,000 / last 10,000 samples)\n",
    "  n_train = 60000\n",
    "  n_test = 10000\n",
    "\n",
    "  train_idx = np.arange(0, n_train)\n",
    "  test_idx = np.arange(n_train, n_train + n_test)\n",
    "\n",
    "  x_train, y_train = X[train_idx], y[train_idx]\n",
    "  x_test, y_test = X[test_idx], y[test_idx]\n",
    "\n",
    "  # Transform np.ndarrays to torch.Tensor\n",
    "  x_train = torch.from_numpy(np.reshape(x_train,\n",
    "                                        (len(x_train),\n",
    "                                         28, 28)).astype(np.float32))\n",
    "  x_test = torch.from_numpy(np.reshape(x_test,\n",
    "                                       (len(x_test),\n",
    "                                        28, 28)).astype(np.float32))\n",
    "\n",
    "  y_train = torch.from_numpy(y_train.astype(int))\n",
    "  y_test = torch.from_numpy(y_test.astype(int))\n",
    "\n",
    "  return (x_train, y_train, x_test, y_test)\n",
    "\n",
    "\n",
    "def init_weights_kaiming_uniform(layer):\n",
    "  \"\"\"\n",
    "  Initializes weights from linear PyTorch layer\n",
    "  with kaiming uniform distribution.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.Module)\n",
    "        Pytorch layer\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "  # check for linear PyTorch layer\n",
    "  if isinstance(layer, nn.Linear):\n",
    "    # initialize weights with kaiming uniform distribution\n",
    "    nn.init.kaiming_uniform_(layer.weight.data)\n",
    "\n",
    "\n",
    "def init_weights_kaiming_normal(layer):\n",
    "  \"\"\"\n",
    "  Initializes weights from linear PyTorch layer\n",
    "  with kaiming normal distribution.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.Module)\n",
    "        Pytorch layer\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "  # check for linear PyTorch layer\n",
    "  if isinstance(layer, nn.Linear):\n",
    "    # initialize weights with kaiming normal distribution\n",
    "    nn.init.kaiming_normal_(layer.weight.data)\n",
    "\n",
    "\n",
    "def get_layer_weights(layer):\n",
    "  \"\"\"\n",
    "  Retrieves learnable parameters from PyTorch layer.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.Module)\n",
    "        Pytorch layer\n",
    "\n",
    "  Returns:\n",
    "    list with learnable parameters\n",
    "  \"\"\"\n",
    "  # initialize output list\n",
    "  weights = []\n",
    "\n",
    "  # copy numpy array representation of each set of learnable parameters\n",
    "  # (layers without learnable parameters yield an empty list)\n",
    "  for item in layer.parameters():\n",
    "    weights.append(item.detach().numpy().copy())\n",
    "\n",
    "  return weights\n",
    "\n",
    "\n",
    "def eval_mse(y_pred, y_true):\n",
    "  \"\"\"\n",
    "  Evaluates Mean Square Error (MSE) between y_pred and y_true\n",
    "\n",
    "  Args:\n",
    "    y_pred (torch.Tensor)\n",
    "        prediction samples\n",
    "\n",
    "    y_true (torch.Tensor)\n",
    "        ground truth samples\n",
    "\n",
    "  Returns:\n",
    "    MSE(y_pred, y_true)\n",
    "  \"\"\"\n",
    "\n",
    "  with torch.no_grad():\n",
    "      criterion = nn.MSELoss()\n",
    "      loss = criterion(y_pred, y_true)\n",
    "\n",
    "  return float(loss)\n",
    "\n",
    "\n",
    "def eval_bce(y_pred, y_true):\n",
    "  \"\"\"\n",
    "  Evaluates Binary Cross-Entropy (BCE) between y_pred and y_true\n",
    "\n",
    "  Args:\n",
    "    y_pred (torch.Tensor)\n",
    "        prediction samples\n",
    "\n",
    "    y_true (torch.Tensor)\n",
    "        ground truth samples\n",
    "\n",
    "  Returns:\n",
    "    BCE(y_pred, y_true)\n",
    "  \"\"\"\n",
    "\n",
    "  with torch.no_grad():\n",
    "      criterion = nn.BCELoss()\n",
    "      loss = criterion(y_pred, y_true)\n",
    "\n",
    "  return float(loss)\n",
    "\n",
    "\n",
    "def plot_weights_ab(encoder_w_a, encoder_w_b, decoder_w_a, decoder_w_b,\n",
    "                    label_a='init', label_b='train',\n",
    "                    bins_encoder=0.5, bins_decoder=1.5):\n",
    "  \"\"\"\n",
    "  Plots row of histograms with encoder and decoder weights\n",
    "  between two training checkpoints.\n",
    "\n",
    "  Args:\n",
    "    encoder_w_a (iterable)\n",
    "        encoder weights at checkpoint a\n",
    "\n",
    "    encoder_w_b (iterable)\n",
    "        encoder weights at checkpoint b\n",
    "\n",
    "    decoder_w_a (iterable)\n",
    "        decoder weights at checkpoint a\n",
    "\n",
    "    decoder_w_b (iterable)\n",
    "        decoder weights at checkpoint b\n",
    "\n",
    "    label_a (string)\n",
    "        label for checkpoint a\n",
    "\n",
    "    label_b (string)\n",
    "        label for checkpoint b\n",
    "\n",
    "    bins_encoder (float)\n",
    "        norm of extreme values for encoder bins\n",
    "\n",
    "    bins_decoder (float)\n",
    "        norm of extreme values for decoder bins\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  plt.figure(figsize=(fig_w * 1.2, fig_h * 1.2))\n",
    "\n",
    "  # plot encoder weights\n",
    "  bins = np.linspace(-bins_encoder, bins_encoder, num=32)\n",
    "\n",
    "  plt.subplot(221)\n",
    "  plt.title('Encoder weights to unit 0')\n",
    "  plt.hist(encoder_w_a[0].flatten(), bins=bins, alpha=0.3, label=label_a)\n",
    "  plt.hist(encoder_w_b[0].flatten(), bins=bins, alpha=0.3, label=label_b)\n",
    "  plt.legend()\n",
    "\n",
    "  plt.subplot(222)\n",
    "  plt.title('Encoder weights to unit 1')\n",
    "  plt.hist(encoder_w_a[1].flatten(), bins=bins, alpha=0.3, label=label_a)\n",
    "  plt.hist(encoder_w_b[1].flatten(), bins=bins, alpha=0.3, label=label_b)\n",
    "  plt.legend()\n",
    "\n",
    "  # plot decoder weights\n",
    "  bins = np.linspace(-bins_decoder, bins_decoder, num=32)\n",
    "\n",
    "  plt.subplot(223)\n",
    "  plt.title('Decoder weights from unit 0')\n",
    "  plt.hist(decoder_w_a[:, 0].flatten(), bins=bins, alpha=0.3, label=label_a)\n",
    "  plt.hist(decoder_w_b[:, 0].flatten(), bins=bins, alpha=0.3, label=label_b)\n",
    "  plt.legend()\n",
    "\n",
    "  plt.subplot(224)\n",
    "  plt.title('Decoder weights from unit 1')\n",
    "  plt.hist(decoder_w_a[:, 1].flatten(), bins=bins, alpha=0.3, label=label_a)\n",
    "  plt.hist(decoder_w_b[:, 1].flatten(), bins=bins, alpha=0.3, label=label_b)\n",
    "  plt.legend()\n",
    "\n",
    "  plt.tight_layout()\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "def plot_row(images, show_n=10, image_shape=None):\n",
    "  \"\"\"\n",
    "  Plots rows of images from list of iterables (iterables: list, numpy array\n",
    "  or torch.Tensor). Also accepts single iterable.\n",
    "  Randomly selects images in each list element if item count > show_n.\n",
    "\n",
    "  Args:\n",
    "    images (iterable or list of iterables)\n",
    "        single iterable with images, or list of iterables\n",
    "\n",
    "    show_n (integer)\n",
    "        maximum number of images per row\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image if vectorized form\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if not isinstance(images, (list, tuple)):\n",
    "    images = [images]\n",
    "\n",
    "  for items_idx, items in enumerate(images):\n",
    "\n",
    "    items = np.array(items)\n",
    "    if items.ndim == 1:\n",
    "      items = np.expand_dims(items, axis=0)\n",
    "\n",
    "    if len(items) > show_n:\n",
    "      selected = np.random.choice(len(items), show_n, replace=False)\n",
    "      items = items[selected]\n",
    "\n",
    "    if image_shape is not None:\n",
    "      items = items.reshape([-1] + list(image_shape))\n",
    "\n",
    "    plt.figure(figsize=(len(items) * 1.5, 2))\n",
    "    for image_idx, image in enumerate(items):\n",
    "\n",
    "      plt.subplot(1, len(items), image_idx + 1)\n",
    "      plt.imshow(image, cmap='gray', vmin=image.min(), vmax=image.max())\n",
    "      plt.axis('off')\n",
    "\n",
    "      plt.tight_layout()\n",
    "\n",
    "\n",
    "def xy_lim(x):\n",
    "  \"\"\"\n",
    "  Return arguments for plt.xlim and plt.ylim calculated from minimum\n",
    "  and maximum of x.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        data to be plotted\n",
    "\n",
    "  Returns:\n",
    "    xlim, ylim : lists with [min, max] limits for each axis\n",
    "  \"\"\"\n",
    "\n",
    "  x_min = np.min(x, axis=0)\n",
    "  x_max = np.max(x, axis=0)\n",
    "\n",
    "  x_min = x_min - np.abs(x_max - x_min) * 0.05 - np.finfo(float).eps\n",
    "  x_max = x_max + np.abs(x_max - x_min) * 0.05 + np.finfo(float).eps\n",
    "\n",
    "  return [x_min[0], x_max[0]], [x_min[1], x_max[1]]\n",
    "\n",
    "\n",
    "def plot_generative(x, decoder_fn, image_shape, n_row=16):\n",
    "  \"\"\"\n",
    "  Plots images reconstructed by decoder_fn from a 2D grid in\n",
    "  latent space that is determined by minimum and maximum values in x.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    decoder_fn (function)\n",
    "        function returning vectorized images from 2D latent space coordinates\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "    n_row (integer)\n",
    "        number of rows in grid\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  xlim, ylim = xy_lim(np.array(x))\n",
    "\n",
    "  dx = (xlim[1] - xlim[0]) / n_row\n",
    "  grid = [np.linspace(ylim[0] + dx / 2, ylim[1] - dx / 2, n_row),\n",
    "          np.linspace(xlim[0] + dx / 2, xlim[1] - dx / 2, n_row)]\n",
    "\n",
    "  canvas = np.zeros((image_shape[0]*n_row, image_shape[1] * n_row))\n",
    "\n",
    "  cmap = plt.get_cmap('gray')\n",
    "\n",
    "  for j, latent_y in enumerate(grid[0][::-1]):\n",
    "    for i, latent_x in enumerate(grid[1]):\n",
    "\n",
    "      latent = np.array([[latent_x, latent_y]], dtype=np.float32)\n",
    "      with torch.no_grad():\n",
    "          x_decoded = decoder_fn(torch.from_numpy(latent))\n",
    "\n",
    "      x_decoded = x_decoded.reshape(image_shape)\n",
    "\n",
    "      canvas[j*image_shape[0]: (j + 1) * image_shape[0],\n",
    "             i*image_shape[1]: (i + 1) * image_shape[1]] = x_decoded\n",
    "\n",
    "  plt.imshow(canvas, cmap=cmap, vmin=canvas.min(), vmax=canvas.max())\n",
    "  plt.axis('off')\n",
    "\n",
    "\n",
    "def plot_latent(x, y, show_n=500, fontdict=None, xy_labels=None):\n",
    "  \"\"\"\n",
    "  Plots digit class of each sample in 2D latent space coordinates.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    show_n (integer)\n",
    "        maximum number of samples to plot\n",
    "\n",
    "    fontdict (dictionary)\n",
    "        optional style option for plt.text\n",
    "\n",
    "    xy_labels (list)\n",
    "        optional list with [xlabel, ylabel]\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if fontdict is None:\n",
    "    fontdict = {'weight': 'bold', 'size': 12}\n",
    "\n",
    "  cmap = plt.get_cmap('tab10')\n",
    "\n",
    "  if len(x) > show_n:\n",
    "    selected = np.random.choice(len(x), show_n, replace=False)\n",
    "    x = x[selected]\n",
    "    y = y[selected]\n",
    "\n",
    "  for my_x, my_y in zip(x, y):\n",
    "    plt.text(my_x[0], my_x[1], str(int(my_y)),\n",
    "             color=cmap(int(my_y) / 10.),\n",
    "             fontdict=fontdict,\n",
    "             horizontalalignment='center',\n",
    "             verticalalignment='center',\n",
    "             alpha=0.8)\n",
    "\n",
    "  if xy_labels is None:\n",
    "    xy_labels = ['$Z_1$', '$Z_2$']\n",
    "\n",
    "  plt.xlabel(xy_labels[0])\n",
    "  plt.ylabel(xy_labels[1])\n",
    "\n",
    "  xlim, ylim = xy_lim(np.array(x))\n",
    "  plt.xlim(xlim)\n",
    "  plt.ylim(ylim)\n",
    "\n",
    "\n",
    "def plot_latent_generative(x, y, decoder_fn, image_shape, title=None,\n",
    "                           xy_labels=None):\n",
    "  \"\"\"\n",
    "  Two horizontal subplots generated with encoder map and decoder grid.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    decoder_fn (function)\n",
    "        function returning vectorized images from 2D latent space coordinates\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "    title (string)\n",
    "        plot title\n",
    "\n",
    "    xy_labels (list)\n",
    "        optional list with [xlabel, ylabel]\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  fig = plt.figure(figsize=(12, 6))\n",
    "\n",
    "  if title is not None:\n",
    "    fig.suptitle(title, y=1.05)\n",
    "\n",
    "  ax = fig.add_subplot(121)\n",
    "  ax.set_title('Encoder map', y=1.05)\n",
    "  plot_latent(x, y, xy_labels=xy_labels)\n",
    "\n",
    "  ax = fig.add_subplot(122)\n",
    "  ax.set_title('Decoder grid', y=1.05)\n",
    "  plot_generative(x, decoder_fn, image_shape)\n",
    "\n",
    "  plt.tight_layout()\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "def plot_latent_ab(x1, x2, y, selected_idx=None,\n",
    "                   title_a='Before', title_b='After', show_n=500):\n",
    "  \"\"\"\n",
    "  Two horizontal subplots with encoder maps.\n",
    "\n",
    "  Args:\n",
    "    x1 (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space (left plot)\n",
    "\n",
    "    x2 (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space (right plot)\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    selected_idx (list of integers)\n",
    "        indexes of elements to be plotted\n",
    "\n",
    "    title_a (string)\n",
    "        title of left plot\n",
    "\n",
    "    title_b (string)\n",
    "        title of right plot\n",
    "\n",
    "    show_n (integer)\n",
    "        maximum number of samples in each plot\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  fontdict = {'weight': 'bold', 'size': 12}\n",
    "\n",
    "  if len(x1) > show_n:\n",
    "\n",
    "    if selected_idx is None:\n",
    "      selected_idx = np.random.choice(len(x1), show_n, replace=False)\n",
    "\n",
    "    x1 = x1[selected_idx]\n",
    "    x2 = x2[selected_idx]\n",
    "    y = y[selected_idx]\n",
    "\n",
    "  plt.figure(figsize=(12, 6))\n",
    "\n",
    "  ax = plt.subplot(121)\n",
    "  ax.set_title(title_a, y=1.05)\n",
    "  plot_latent(x1, y, fontdict=fontdict)\n",
    "\n",
    "  ax = plt.subplot(122)\n",
    "  ax.set_title(title_b, y=1.05)\n",
    "  plot_latent(x2, y, fontdict=fontdict)\n",
    "\n",
    "  plt.tight_layout()\n",
    "\n",
    "\n",
    "def runSGD(net, input_train, input_test, criterion='bce',\n",
    "           n_epochs=10, batch_size=32, verbose=False):\n",
    "  \"\"\"\n",
    "  Trains autoencoder network with minibatch stochastic gradient descent using\n",
    "  the Adam optimizer and the chosen loss criterion. Train samples are shuffled\n",
    "  each epoch, and train and test losses are displayed at the end of each epoch.\n",
    "  If verbose, final MSE and BCE losses are also printed. Plots training loss\n",
    "  at each minibatch (subsampled to at most 500 evenly spaced values).\n",
    "\n",
    "  Args:\n",
    "    net (torch network)\n",
    "        ANN object (nn.Module)\n",
    "\n",
    "    input_train (torch.Tensor)\n",
    "        vectorized input images from train set\n",
    "\n",
    "    input_test (torch.Tensor)\n",
    "        vectorized input images from test set\n",
    "\n",
    "    criterion (string)\n",
    "        train loss: 'bce' or 'mse'\n",
    "\n",
    "    n_epochs (integer)\n",
    "        number of full iterations of training data\n",
    "\n",
    "    batch_size (integer)\n",
    "        number of elements in each minibatch\n",
    "\n",
    "    verbose (boolean)\n",
    "        whether to print final MSE and BCE losses\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize loss function\n",
    "  if criterion == 'mse':\n",
    "    loss_fn = nn.MSELoss()\n",
    "  elif criterion == 'bce':\n",
    "    loss_fn = nn.BCELoss()\n",
    "  else:\n",
    "    raise ValueError('Please specify either \"mse\" or \"bce\" for loss criterion')\n",
    "\n",
    "  # Initialize Adam optimizer (a variant of stochastic gradient descent)\n",
    "  optimizer = optim.Adam(net.parameters())\n",
    "\n",
    "  # Placeholder for loss\n",
    "  track_loss = []\n",
    "\n",
    "  print('Epoch', '\\t', 'Loss train', '\\t', 'Loss test')\n",
    "  for i in range(n_epochs):\n",
    "\n",
    "    shuffle_idx = np.random.permutation(len(input_train))\n",
    "    batches = torch.split(input_train[shuffle_idx], batch_size)\n",
    "\n",
    "    for batch in batches:\n",
    "\n",
    "      output_train = net(batch)\n",
    "      loss = loss_fn(output_train, batch)\n",
    "      optimizer.zero_grad()\n",
    "      loss.backward()\n",
    "      optimizer.step()\n",
    "\n",
    "      # Keep track of loss at each minibatch\n",
    "      track_loss += [float(loss)]\n",
    "\n",
    "    loss_epoch = f'{i + 1} / {n_epochs}'\n",
    "    with torch.no_grad():\n",
    "      output_train = net(input_train)\n",
    "      loss_train = loss_fn(output_train, input_train)\n",
    "      loss_epoch += f'\\t {loss_train:.4f}'\n",
    "\n",
    "      output_test = net(input_test)\n",
    "      loss_test = loss_fn(output_test, input_test)\n",
    "      loss_epoch += f'\\t\\t {loss_test:.4f}'\n",
    "\n",
    "    print(loss_epoch)\n",
    "\n",
    "  if verbose:\n",
    "    # Print final loss\n",
    "    loss_mse = f'\\nMSE\\t {eval_mse(output_train, input_train):0.4f}'\n",
    "    loss_mse += f'\\t\\t {eval_mse(output_test, input_test):0.4f}'\n",
    "    print(loss_mse)\n",
    "\n",
    "    loss_bce = f'BCE\\t {eval_bce(output_train, input_train):0.4f}'\n",
    "    loss_bce += f'\\t\\t {eval_bce(output_test, input_test):0.4f}'\n",
    "    print(loss_bce)\n",
    "\n",
    "  # Plot loss\n",
    "  step = int(np.ceil(len(track_loss) / 500))\n",
    "  x_range = np.arange(0, len(track_loss), step)\n",
    "  plt.figure()\n",
    "  plt.plot(x_range, track_loss[::step], 'C0')\n",
    "  plt.xlabel('Iterations')\n",
    "  plt.ylabel('Loss')\n",
    "  plt.xlim([0, None])\n",
    "  plt.ylim([0, None])\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 1: Introduction to autoencoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T00:59:38.675459Z",
     "iopub.status.busy": "2021-05-25T00:59:38.674596Z",
     "iopub.status.idle": "2021-05-25T00:59:38.708151Z",
     "shell.execute_reply": "2021-05-25T00:59:38.707641Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Video 2: Autoencoders\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"hefek_yhEKs\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial introduces typical elements of autoencoders, which learn low-dimensional representations of data through an auxiliary task of compression and decompression. In general, these networks have an equal number of input and output units and a *bottleneck layer* with fewer units.\n",
    "\n",
    "![Single hidden layer ANN autoencoder](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/ae-ann-1h.png)\n",
    "\n",
    "Autoencoder architectures have *encoder* and *decoder* components:\n",
    "* The *encoder* compresses high-dimensional inputs into lower-dimensional coordinates of the *bottleneck layer*\n",
    "* The *decoder* expands *bottleneck layer* coordinates back to the original dimensionality\n",
    "\n",
    "Each input presented to the autoencoder maps to a point in the lower-dimensional *latent space* spanned by the bottleneck layer units.\n",
    "\n",
    "&nbsp; \n",
    "\n",
    "The reconstruction loss, i.e., the difference between inputs and outputs, is backpropagated to adjust the weights so that the network compresses and decompresses data better. Autoencoders are examples of models that automatically build internal representations of the world and use them to predict unseen data.\n",
    "\n",
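    "A minimal PyTorch sketch of this encoder, bottleneck and decoder structure is shown below. It assumes vectorized 28x28 inputs and a 2-unit bottleneck; the layer choices (a linear encoder, and a linear decoder followed by a sigmoid), the names `encoder`, `decoder` and `encoding_size`, and the random stand-in batch are illustrative assumptions, not necessarily the architecture used later in this tutorial. It runs a single reconstruction step to show where the loss and backpropagation come in.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Illustrative sizes: 28x28 images flattened to 784 inputs, 2-unit bottleneck\n",
    "input_size, encoding_size = 28 * 28, 2\n",
    "\n",
    "# Encoder: compresses 784 pixels into 2 latent coordinates\n",
    "encoder = nn.Linear(input_size, encoding_size)\n",
    "# Decoder: expands 2 latent coordinates back to 784 pixels in [0, 1]\n",
    "decoder = nn.Sequential(nn.Linear(encoding_size, input_size), nn.Sigmoid())\n",
    "model = nn.Sequential(encoder, decoder)\n",
    "\n",
    "# Random stand-in for a minibatch of 64 vectorized images with values in [0, 1]\n",
    "x = torch.rand(64, input_size)\n",
    "z = encoder(x)        # latent coordinates, shape (64, 2)\n",
    "x_hat = decoder(z)    # reconstructions, shape (64, 784)\n",
    "\n",
    "# Reconstruction loss compares the outputs with the inputs themselves\n",
    "loss_fn = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "loss = loss_fn(x_hat, x)\n",
    "optimizer.zero_grad()\n",
    "loss.backward()       # backpropagate the reconstruction error\n",
    "optimizer.step()      # adjust encoder and decoder weights\n",
    "```\n",
    "\n",
    "The target of the loss is the input `x` itself, so no digit labels are needed to train the network.\n",
    "\n",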
    "We'll use fully-connected ANN architectures due to their lower computational requirements. The inputs to these ANNs are *vectorized* versions of the images, i.e., each 28x28 image is flattened into a single vector of 784 pixels."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 2: The MNIST dataset\n",
    "The [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database) contains handwritten digits in square grayscale images of 28x28 pixels. There are 60,000 training images and 10,000 test images from different writers.\n",
    "\n",
    "Get acquainted with the data by inspecting data type, shape, and visualizing samples with the function `plot_row`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Helper function:** `plot_row`\n",
    "\n",
    "Please uncomment the line below to inspect this function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T00:59:38.711968Z",
     "iopub.status.busy": "2021-05-25T00:59:38.711382Z",
     "iopub.status.idle": "2021-05-25T00:59:38.714484Z",
     "shell.execute_reply": "2021-05-25T00:59:38.714945Z"
    }
   },
   "outputs": [],
   "source": [
    "# help(plot_row)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 2.1: Download and prepare MNIST dataset\n",
    "We use the helper function `downloadMNIST` to download the dataset, transform it into `torch.Tensor` and assign train and test datasets to (`x_train`, `y_train`) and (`x_test`, `y_test`), respectively.\n",
    "\n",
    "(`x_train`, `x_test`) contain images and (`y_train`, `y_test`) contain labels from `0` to `9`.\n",
    "\n",
    "The original pixel values are integers between `0` and `255`. We rescale them between `0` and `1`, a more favorable range for training the autoencoders in this tutorial.\n",
    "\n",
    "The ANNs take *vectorized* images as input, i.e., each image flattened into a single row of pixels. We reshape the training and test images into *vectorized* versions with the method `.reshape` and store them in the variables `input_train` and `input_test`, respectively. The variable `image_shape` stores the shape of the images, and `input_size` stores the size of the *vectorized* versions.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below\n",
    "\n",
    "**Questions:**\n",
    "* What are the shape and numeric representations of `x_train` and `input_train`?\n",
    "* What is the image shape?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T00:59:38.720765Z",
     "iopub.status.busy": "2021-05-25T00:59:38.720224Z",
     "iopub.status.idle": "2021-05-25T01:00:07.006834Z",
     "shell.execute_reply": "2021-05-25T01:00:07.007582Z"
    }
   },
   "outputs": [],
   "source": [
    "# Download MNIST\n",
    "x_train, y_train, x_test, y_test = downloadMNIST()\n",
    "\n",
    "x_train = x_train / 255\n",
    "x_test = x_test / 255\n",
    "\n",
    "image_shape = x_train.shape[1:]\n",
    "\n",
    "input_size = np.prod(image_shape)\n",
    "\n",
    "input_train = x_train.reshape([-1, input_size])\n",
    "input_test = x_test.reshape([-1, input_size])\n",
    "\n",
    "test_selected_idx = np.random.choice(len(x_test), 10, replace=False)\n",
    "train_selected_idx = np.random.choice(len(x_train), 10, replace=False)\n",
    "\n",
    "print(f'shape x_train \\t\\t {x_train.shape}')\n",
    "print(f'shape x_test \\t\\t {x_test.shape}')\n",
    "print(f'shape image \\t\\t {image_shape}')\n",
    "print(f'shape input_train \\t {input_train.shape}')\n",
    "print(f'shape input_test \\t {input_test.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 2.2: Visualize samples\n",
    "The variables `train_selected_idx` and `test_selected_idx` store 10 random indexes from the train and test data.\n",
    "\n",
    "We use the function `np.random.choice` to select 10 indexes each from the train and test sets without replacement (`replace=False`), so that no index is selected twice.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below\n",
    "* The top row displays different samples each time, while the bottom row always displays the same samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:07.011237Z",
     "iopub.status.busy": "2021-05-25T01:00:07.010719Z",
     "iopub.status.idle": "2021-05-25T01:00:08.053410Z",
     "shell.execute_reply": "2021-05-25T01:00:08.052934Z"
    }
   },
   "outputs": [],
   "source": [
    "# top row: random images from test set\n",
    "# bottom row: images selected with test_selected_idx\n",
    "\n",
    "plot_row([x_test, x_test[test_selected_idx]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 3: Latent space visualization\n",
    "This section introduces tools for visualization of latent space and applies them to *Principal Component Analysis (PCA)*, already introduced in tutorial *W1D5 Dimensionality reduction*. Please see the exercise in the *Bonus* section for *Non-negative Matrix Factorization (NMF)*.\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "The plotting function `plot_latent_generative` helps visualize the encoding of inputs from high dimension into 2D latent space, and decoding back to the original dimension. This function produces two plots:\n",
    "\n",
    "* **Encoder map** shows the mapping from input images to coordinates $(z_1, z_2)$ in latent space, with overlaid digit labels\n",
    "* **Decoder grid** shows reconstructions from a grid of latent space coordinates $(z_1, z_2)$\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "![Latent space visualization](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/latent_space_plots_noaxis.png)\n",
    "\n",
    "The latent space representation is a new coordinate system $(z_1, z_2)$ that hopefully captures relevant structure from high-dimensional data. The coordinates of each input only matter relative to those of other inputs, i.e., we often care about separability between different classes of digits rather than their location.\n",
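    "\n",
    "To see why only relative positions matter, here is a small sketch with made-up 2D latent coordinates (an assumption, not actual encodings): rotating and shifting the whole latent space leaves all pairwise distances, and hence the separability between clusters, unchanged.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "z = rng.normal(size=(100, 2))  # made-up latent coordinates (z_1, z_2)\n",
    "\n",
    "theta = np.pi / 3  # arbitrary rotation angle\n",
    "rotation = np.array([[np.cos(theta), -np.sin(theta)],\n",
    "                     [np.sin(theta), np.cos(theta)]])\n",
    "z_new = z @ rotation.T + np.array([5.0, -2.0])  # rotate, then translate\n",
    "\n",
    "def pairwise_distances(points):\n",
    "    # Euclidean distance between every pair of points\n",
    "    diff = points[:, None, :] - points[None, :, :]\n",
    "    return np.sqrt((diff ** 2).sum(axis=-1))\n",
    "\n",
    "print(np.allclose(pairwise_distances(z), pairwise_distances(z_new)))\n",
    "```\n",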
    "\n",
    "&nbsp;\n",
    "\n",
    "The encoder map provides direct insight into the organization of latent space. Keep in mind that latent space only contains coordinates $(z_1, z_2)$. We overlay additional information such as digit labels for insight into the latent space structure. \n",
    "\n",
    "The plot on the left is the raw latent space representation corresponding to the plot on the right with digit labels overlaid.\n",
    "\n",
    "![Raw latent space visualization](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/latent_space_plots_nolabel.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.1: MNIST with PCA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise 1: Visualize PCA latent space (2D)\n",
    "The tutorial *W1D5 Dimensionality reduction* introduced PCA decomposition. The case of two principal components (PCA1 and PCA2) generates a latent space in 2D.\n",
    "\n",
    "![Latent space visualization PCA](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/latent_space_plots_pca.png)\n",
    "\n",
    "In tutorial W1D5, PCA decomposition was implemented directly and also by using module [sklearn.decomposition](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.decomposition) from the package [scikit-learn](https://scikit-learn.org). This module includes several matrix decomposition algorithms that are useful as dimensionality reduction techniques.\n",
    "\n",
    "Their usage is very straightforward, as shown by this example for truncated SVD:\n",
    "```\n",
    "from sklearn import decomposition\n",
    "\n",
    "svd = decomposition.TruncatedSVD(n_components=2)\n",
    "\n",
    "svd.fit(input_train)\n",
    "\n",
    "svd_latent_train = svd.transform(input_train)\n",
    "svd_latent_test = svd.transform(input_test)\n",
    "\n",
    "svd_reconstruction_train = svd.inverse_transform(svd_latent_train)\n",
    "svd_reconstruction_test = svd.inverse_transform(svd_latent_test)\n",
    "```\n",
    "\n",
    "In this exercise, we'll use `decomposition.PCA` (docs [here](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html)) for PCA decomposition.\n",
    "\n",
    "**Instructions:**\n",
    "* Initialize `decomposition.PCA` with 2 components\n",
    "* Fit `input_train` with `.fit` method of `decomposition.PCA`\n",
    "* Obtain latent space representation of `input_test`\n",
    "* Visualize latent space with `plot_latent_generative`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Helper function:** `plot_latent_generative`\n",
    "\n",
    "Please execute the cell below to inspect this function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:08.057962Z",
     "iopub.status.busy": "2021-05-25T01:00:08.057412Z",
     "iopub.status.idle": "2021-05-25T01:00:08.061777Z",
     "shell.execute_reply": "2021-05-25T01:00:08.061287Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title\n",
    "\n",
    "# @markdown Execute this cell to inspect `plot_latent_generative`!\n",
    "help(plot_latent_generative)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:08.066050Z",
     "iopub.status.busy": "2021-05-25T01:00:08.064783Z",
     "iopub.status.idle": "2021-05-25T01:00:08.066685Z",
     "shell.execute_reply": "2021-05-25T01:00:08.067139Z"
    }
   },
   "outputs": [],
   "source": [
    "####################################################\n",
    "## TODO for students: perform PCA and visualize latent space and reconstruction\n",
    "####################################################\n",
    "# create the model\n",
    "# pca = decomposition.PCA(...)\n",
    "# fit the model on training data\n",
    "# pca.fit(...)\n",
    "# transformation on 2D space\n",
    "# pca_latent_test = pca.transform(...)\n",
    "\n",
    "# Uncomment to test your code!\n",
    "# plot_latent_generative(pca_latent_test, y_test, pca.inverse_transform,\n",
    "#                        image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:08.071395Z",
     "iopub.status.busy": "2021-05-25T01:00:08.070838Z",
     "iopub.status.idle": "2021-05-25T01:00:13.593020Z",
     "shell.execute_reply": "2021-05-25T01:00:13.593457Z"
    }
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/Bonus_UnsupervisedLearning/solutions/Tutorial1_Solution_464a2875.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=807 height=416 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/Bonus_UnsupervisedLearning/static/Tutorial1_Solution_464a2875_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.2: Qualitative analysis PCA\n",
    "The encoder map shows how well the encoder is distinguishing between digit classes. We see that digits `1` and `0` are in opposite regions of the first principal component axis, and similarly for digits `9` and `3` for the second principal component.\n",
    "\n",
    "The decoder grid indicates how well the decoder is recovering images from latent space coordinates. Overall, digits `1`, `0`, and `9` are the most recognizable.\n",
    "\n",
    "Let's inspect the principal components to understand these observations better. The principal components are available as `pca.components_` and shown below.\n",
    "\n",
    "![principal components](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/pca-components.png)\n",
    "\n",
    "Notice that the first principal component encodes digit `0` with positive values (in white) and digit `1` with negative values (in black). The colormap shows minimum values in black and maximum values in white; we can infer their signs from the coordinates of digits `0` and `1` along the first principal component axis.\n",
    "\n",
    "&nbsp; \n",
    "\n",
    "The first principal component axis encodes the \"thickness\" of the digits: thin digits on the left and thick digits on the right.\n",
    "\n",
    "Similarly, the second principal component encodes digit `9` with positive values (in white) and digit `3` with negative values (in black).\n",
    "\n",
    "The second principal component axis is encoding, well, another aspect besides \"thickness\" of digits (why?).\n",
    "\n",
    "The reconstruction grid also shows that digits `4` and `7` are indistinguishable from digit `9` and similarly for digits `2` and `3`.\n",
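    "\n",
    "The decoder grid is produced by `pca.inverse_transform`. As a sketch on random stand-in data (an assumption, not MNIST), this is how scikit-learn's `PCA` maps between input space and latent space through `components_` and `mean_`, with the default `whiten=False`. The variable name `pca_demo` is hypothetical, chosen so it does not clash with the tutorial's `pca`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn import decomposition\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.random((500, 16))  # stand-in for flattened images, values in [0, 1]\n",
    "\n",
    "pca_demo = decomposition.PCA(n_components=2)\n",
    "pca_demo.fit(x)\n",
    "\n",
    "z = pca_demo.transform(x)  # latent coordinates (z_1, z_2)\n",
    "# encoding: center the data, then project onto the principal components\n",
    "z_manual = (x - pca_demo.mean_) @ pca_demo.components_.T\n",
    "\n",
    "x_hat = pca_demo.inverse_transform(z)\n",
    "# decoding: weighted sum of components, plus the mean image\n",
    "x_hat_manual = z @ pca_demo.components_ + pca_demo.mean_\n",
    "\n",
    "print(np.allclose(z, z_manual), np.allclose(x_hat, x_hat_manual))\n",
    "```\n",
    "\n",
    "Each row of `components_` has the same length as a flattened input, which is why the principal components can be displayed as images with `plot_row(pca_components, image_shape=image_shape)`.\n",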
    "\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below\n",
    "* Plot reconstruction samples a few times to get a visual feel of the digit confusions (use keyword `image_shape` to visualize the vectorized images)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:13.610343Z",
     "iopub.status.busy": "2021-05-25T01:00:13.596414Z",
     "iopub.status.idle": "2021-05-25T01:00:13.686591Z",
     "shell.execute_reply": "2021-05-25T01:00:13.687375Z"
    }
   },
   "outputs": [],
   "source": [
    "pca_components = pca.components_\n",
    "\n",
    "plot_row(pca_components, image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:13.691077Z",
     "iopub.status.busy": "2021-05-25T01:00:13.689944Z",
     "iopub.status.idle": "2021-05-25T01:00:14.765940Z",
     "shell.execute_reply": "2021-05-25T01:00:14.766350Z"
    }
   },
   "outputs": [],
   "source": [
    "pca_output_test = pca.inverse_transform(pca_latent_test)\n",
    "\n",
    "plot_row([input_test, pca_output_test], image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 4: ANN autoencoder\n",
    "Let's implement a *shallow* ANN autoencoder with a single hidden layer.\n",
    "\n",
    "![Single hidden layer ANN autoencoder](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/ae-ann-1h.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Design ANN autoencoder (32D)\n",
    "Here we introduce a technique for quickly building PyTorch models, best suited for the initial exploration phase of your research project.\n",
    "\n",
    "The object-oriented programming (OOP) approach presented in tutorial W3D4 becomes the better choice once you understand the model's architecture and components well.\n",
    "\n",
    "Using this concise technique, a network equivalent to `DeepNetReLU` from W3D4 tutorial 1 is defined as:\n",
    "```\n",
    "model = nn.Sequential(nn.Linear(n_input, n_hidden),\n",
    "                      nn.ReLU(),\n",
    "                      nn.Linear(n_hidden, n_output))\n",
    "```\n",
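    "\n",
    "For comparison, a sketch of the same two-layer network written in the OOP style. The class name `TwoLayerNet` and the layer sizes are hypothetical, and `DeepNetReLU` from W3D4 may differ in its details; copying the weights across shows that both styles compute the same function:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "n_input, n_hidden, n_output = 784, 32, 10  # hypothetical sizes\n",
    "\n",
    "class TwoLayerNet(nn.Module):\n",
    "    def __init__(self, n_input, n_hidden, n_output):\n",
    "        super().__init__()\n",
    "        self.in_layer = nn.Linear(n_input, n_hidden)\n",
    "        self.out_layer = nn.Linear(n_hidden, n_output)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.out_layer(torch.relu(self.in_layer(x)))\n",
    "\n",
    "seq_model = nn.Sequential(nn.Linear(n_input, n_hidden),\n",
    "                          nn.ReLU(),\n",
    "                          nn.Linear(n_hidden, n_output))\n",
    "oop_model = TwoLayerNet(n_input, n_hidden, n_output)\n",
    "\n",
    "# copy the Sequential weights into the OOP model\n",
    "oop_model.in_layer.load_state_dict(seq_model[0].state_dict())\n",
    "oop_model.out_layer.load_state_dict(seq_model[2].state_dict())\n",
    "\n",
    "x = torch.rand(5, n_input)\n",
    "print(torch.allclose(seq_model(x), oop_model(x)))\n",
    "```\n",
    "\n",
    "The `Sequential` version is shorter, while the OOP version makes it easy to add custom logic to `forward` later on.\n",
    "\n",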
    "Designing and training efficient neural networks currently requires some thought, experience, and testing for choosing between available options, such as the number of hidden layers, loss function, optimizer function, mini-batch size, etc. Choosing these hyper-parameters may soon become more of an engineering process with our increasing analytical understanding of these systems and their learning dynamics.\n",
    "\n",
    "The references below are great to learn more about neural network design and best practices:\n",
    "* [Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com) by Michael Nielsen is an excellent reference for beginners\n",
    "* [Deep Learning](http://www.deeplearningbook.org) by Ian Goodfellow, Yoshua Bengio, and Aaron Courville provides in-depth and extensive coverage\n",
    "* [A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay](https://arxiv.org/abs/1803.09820) by L. Smith covers efficient ways to set hyper-parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 2: Design ANN autoencoder\n",
    "We will use `encoding_size=32` rectifier ([ReLU](https://icml.cc/Conferences/2010/papers/432.pdf)) units in the bottleneck layer and sigmoid units in the output layer. You can read more about activation functions [here](https://en.wikipedia.org/wiki/Activation_function) and rectifiers [here](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)).\n",
    "\n",
    "![ReLU unit](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/relu.png)\n",
    "\n",
    "![Sigmoid unit](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/sigmoid.png)\n",
    "\n",
    "We rescaled images to values between `0` and `1` for compatibility with sigmoid units in the output (why?). This mapping is without loss of generality, since any finite range can be mapped one-to-one onto values between `0` and `1`.\n",
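    "\n",
    "A sketch of such a one-to-one mapping (min-max scaling), shown for 8-bit pixel values in `[0, 255]` as in the earlier division by `255`. The helper names `to_unit_range` and `from_unit_range` are hypothetical:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def to_unit_range(x, low, high):\n",
    "    # map values from [low, high] onto [0, 1]\n",
    "    return (x - low) / (high - low)\n",
    "\n",
    "def from_unit_range(u, low, high):\n",
    "    # inverse mapping from [0, 1] back onto [low, high]\n",
    "    return u * (high - low) + low\n",
    "\n",
    "pixels = np.array([0, 64, 128, 255], dtype=float)\n",
    "scaled = to_unit_range(pixels, 0, 255)\n",
    "print(scaled.min(), scaled.max())  # 0.0 1.0\n",
    "\n",
    "restored = from_unit_range(scaled, 0, 255)\n",
    "print(np.allclose(restored, pixels))  # True\n",
    "```\n",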
    "\n",
    "Both ReLU and sigmoid units provide non-linear computation to the encoder and decoder components. The sigmoid units additionally ensure that output values lie in the same range as the inputs. These units could be replaced by ReLU units, in which case output values could be greater than 1. The sigmoid units of the decoder thus enforce a numerical constraint that expresses our *domain knowledge* of the data.\n",
    "\n",
    "**Instructions**\n",
    "* `nn.Sequential` defines and initializes an ANN with layer sizes (`input_size, encoding_size, input_size`)\n",
    "* `nn.Linear` defines a linear layer with the size of the inputs and outputs as arguments\n",
    "* `nn.ReLU` and `nn.Sigmoid` encode ReLU and sigmoid units\n",
    "* Visualize the initial output using `plot_row` with input and output images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:14.771473Z",
     "iopub.status.busy": "2021-05-25T01:00:14.770377Z",
     "iopub.status.idle": "2021-05-25T01:00:14.774704Z",
     "shell.execute_reply": "2021-05-25T01:00:14.774206Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 32\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.ReLU(),\n",
    "    ######################################################################\n",
    "    ## TODO for students: add linear and sigmoid layers\n",
    "    ######################################################################\n",
    "    # insert your code here to add the layer\n",
    "    # nn.Linear(...),\n",
    "    # insert the activation function\n",
    "    # ....\n",
    "    )\n",
    "\n",
    "print(f'Model structure \\n\\n {model}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:14.779773Z",
     "iopub.status.busy": "2021-05-25T01:00:14.778264Z",
     "iopub.status.idle": "2021-05-25T01:00:14.781806Z",
     "shell.execute_reply": "2021-05-25T01:00:14.782263Z"
    }
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/Bonus_UnsupervisedLearning/solutions/Tutorial1_Solution_e7182519.py)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**SAMPLE OUTPUT**\n",
    "\n",
    "```\n",
    "Sequential(\n",
    "  (0): Linear(in_features=784, out_features=32, bias=True)\n",
    "  (1): ReLU()\n",
    "  (2): Linear(in_features=32, out_features=784, bias=True)\n",
    "  (3): Sigmoid()\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:14.790598Z",
     "iopub.status.busy": "2021-05-25T01:00:14.790044Z",
     "iopub.status.idle": "2021-05-25T01:00:15.896275Z",
     "shell.execute_reply": "2021-05-25T01:00:15.896714Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test.float(), output_test], image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train autoencoder (32D)\n",
    "The function `runSGD` trains the autoencoder with minibatch stochastic gradient descent using the Adam optimizer (`optim.Adam`) and provides a choice between Mean Squared Error (MSE with `nn.MSELoss`) and Binary Cross-Entropy (BCE with `nn.BCELoss`).\n",
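    "\n",
    "The helper's code is not shown here, so the following is only a hypothetical sketch of such a training loop, `train_autoencoder_sketch`, under the assumptions stated in its comments: minibatches are reshuffled every epoch, the Adam learning rate is our own choice, and the reconstruction target is the input itself:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "def train_autoencoder_sketch(model, x_train, criterion='mse',\n",
    "                             n_epochs=10, batch_size=64, lr=1e-3):\n",
    "    # hypothetical stand-in for runSGD: minibatch training with Adam\n",
    "    loss_fn = nn.MSELoss() if criterion == 'mse' else nn.BCELoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "    epoch_losses = []\n",
    "    for epoch in range(n_epochs):\n",
    "        perm = torch.randperm(len(x_train))  # reshuffle every epoch\n",
    "        total = 0.0\n",
    "        for start in range(0, len(x_train), batch_size):\n",
    "            batch = x_train[perm[start:start + batch_size]]\n",
    "            optimizer.zero_grad()\n",
    "            loss = loss_fn(model(batch), batch)  # target is the input itself\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            total += loss.item() * len(batch)\n",
    "        epoch_losses.append(total / len(x_train))  # mean loss per sample\n",
    "    return epoch_losses\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x_toy = torch.rand(512, 16)  # toy data in [0, 1]\n",
    "ae = nn.Sequential(nn.Linear(16, 4), nn.ReLU(),\n",
    "                   nn.Linear(4, 16), nn.Sigmoid())\n",
    "losses = train_autoencoder_sketch(ae, x_toy, criterion='bce',\n",
    "                                  n_epochs=10, lr=1e-2)\n",
    "print(losses)\n",
    "```\n",
    "\n",
    "Unlike this sketch, `runSGD` also receives `input_test`, presumably to report the loss on test data; that part is not reproduced here.\n",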
    "\n",
    "The figures below illustrate these losses, where $\\hat{Y}$ is the output value, and $Y$ is the target value.\n",
    "\n",
    "![MSE loss](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/mse.png)\n",
    "\n",
    "![BCE loss](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/bce.png)\n",
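    "\n",
    "A NumPy sketch of both losses, written from their standard definitions and averaged over all pixels, which matches the default `reduction='mean'` of `nn.MSELoss` and `nn.BCELoss`. The clipping constant `eps` is our own assumption to keep `log()` finite; PyTorch instead clamps the log terms inside `nn.BCELoss`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def mse_loss(y_hat, y):\n",
    "    # mean of squared differences over all pixels\n",
    "    return np.mean((y_hat - y) ** 2)\n",
    "\n",
    "def bce_loss(y_hat, y, eps=1e-7):\n",
    "    # mean binary cross-entropy over all pixels; eps keeps log() finite\n",
    "    y_hat = np.clip(y_hat, eps, 1 - eps)\n",
    "    return -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))\n",
    "\n",
    "y = np.array([1.0, 1.0, 0.0])      # target pixel values\n",
    "y_hat = np.array([0.9, 0.1, 0.2])  # output pixel values\n",
    "\n",
    "print(mse_loss(y_hat, y))  # (0.01 + 0.81 + 0.04) / 3\n",
    "print(bce_loss(y_hat, y))\n",
    "```\n",
    "\n",
    "The white target pixel rendered dark (`0.1`) contributes `0.81` to the MSE sum but about `2.30` to the BCE sum.\n",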
    "\n",
    "Train the network for `n_epochs=10` epochs and `batch_size=64` with `runSGD` and MSE loss, and visualize a few reconstructed samples.\n",
    "\n",
    "Please execute the cells below to construct and train the model!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:15.902327Z",
     "iopub.status.busy": "2021-05-25T01:00:15.901000Z",
     "iopub.status.idle": "2021-05-25T01:00:32.722731Z",
     "shell.execute_reply": "2021-05-25T01:00:32.722222Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 32\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(encoding_size, input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "n_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "runSGD(model, input_train, input_test, criterion='mse',\n",
    "       n_epochs=n_epochs, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:32.736760Z",
     "iopub.status.busy": "2021-05-25T01:00:32.736220Z",
     "iopub.status.idle": "2021-05-25T01:00:33.767198Z",
     "shell.execute_reply": "2021-05-25T01:00:33.767961Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the loss function\n",
    "The loss function determines what the network is optimizing during training, and this translates to the visual aspect of reconstructed images. \n",
    "\n",
    "For example, isolated black pixels in the middle of white regions are very unlikely and look noisy. The network can prioritize avoiding such scenarios by maximally penalizing white pixels that turn out black and vice-versa.\n",
    "\n",
    "The figure below compares MSE with BCE with a target pixel value $Y=1$, and the output ranging from $\\hat{Y}\\in [0, 1]$. The MSE loss has a gentle quadratic rise in this range. Notice how BCE loss dramatically increases for dark pixels $\\hat{Y}$ lower than 0.4.\n",
    "\n",
    "![bce vs. MSE loss](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/bce-mse.png)\n",
    "\n",
    "Let's look at their derivatives $d\\,\\text{Loss}/d\\,\\hat{Y}$ to make this comparison more objective. The derivative of MSE loss is linear in $\\hat{Y}$ and never exceeds $2$ in magnitude, whereas the magnitude of the BCE derivative grows as $1/\\hat{Y}$ for dark pixel values (why?). \n",
    "\n",
    "![bce vs. MSE loss](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/bce-mse-dloss.png)\n",
    "\n",
    "We reduced the plotting range to $[0.05, 1]$ to share the same y-axis scale for both loss functions (why?).\n",
    "\n",
    "Let's switch to BCE loss and verify the effects of maximally penalizing white pixels that turn out black and vice-versa. The visual differences between losses will be subtle since the network is converging well in both cases.\n",
    "\n",
    "**Look for isolated white/black pixel areas in MSE loss reconstructions.**\n",
    "\n",
    "We will first retrain for only `2` epochs under MSE loss to accentuate the differences, and then do the same under BCE loss.\n",
    "\n",
    "Please execute the cells below to train with MSE and BCE, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:33.773328Z",
     "iopub.status.busy": "2021-05-25T01:00:33.772083Z",
     "iopub.status.idle": "2021-05-25T01:00:37.478551Z",
     "shell.execute_reply": "2021-05-25T01:00:37.478075Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 32\n",
    "n_epochs = 2\n",
    "batch_size = 64\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(encoding_size, input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "runSGD(model, input_train, input_test, criterion='mse',\n",
    "       n_epochs=n_epochs, batch_size=batch_size, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:37.484169Z",
     "iopub.status.busy": "2021-05-25T01:00:37.483648Z",
     "iopub.status.idle": "2021-05-25T01:00:38.542477Z",
     "shell.execute_reply": "2021-05-25T01:00:38.542015Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:38.547724Z",
     "iopub.status.busy": "2021-05-25T01:00:38.546510Z",
     "iopub.status.idle": "2021-05-25T01:00:44.070541Z",
     "shell.execute_reply": "2021-05-25T01:00:44.069739Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 32\n",
    "n_epochs = 2\n",
    "batch_size = 64\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(encoding_size, input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "runSGD(model, input_train, input_test, criterion='bce',\n",
    "       n_epochs=n_epochs, batch_size=batch_size, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:44.076461Z",
     "iopub.status.busy": "2021-05-25T01:00:44.075935Z",
     "iopub.status.idle": "2021-05-25T01:00:45.052029Z",
     "shell.execute_reply": "2021-05-25T01:00:45.051538Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Design ANN autoencoder (2D)\n",
    "Reducing the number of bottleneck units to `encoding_size=2` generates a 2D latent space as for PCA before. The coordinates $(z_1, z_2)$  of the encoder map represent unit activations in the bottleneck layer.\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "![encoder map for autoencoder](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/latent_space_plots_ae.png)\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "The `encoder` component provides ($z_1, z_2$) coordinates in latent space, and the `decoder` component generates image reconstructions from ($z_1, z_2$). Slicing the autoencoder's sequence of layers defines these sub-networks:\n",
    "\n",
    "```\n",
    "model = nn.Sequential(...)\n",
    "encoder = model[:n]\n",
    "decoder = model[n:]\n",
    "```\n",
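    "\n",
    "A sketch with toy layer sizes (an assumption) checking that the slices reuse the same layers: chaining `encoder` and `decoder` reproduces `model`, and both share the autoencoder's parameters, so training `model` also trains them.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "input_size, encoding_size = 16, 2  # toy sizes for illustration\n",
    "\n",
    "model = nn.Sequential(nn.Linear(input_size, encoding_size),\n",
    "                      nn.PReLU(),\n",
    "                      nn.Linear(encoding_size, input_size),\n",
    "                      nn.Sigmoid())\n",
    "encoder = model[:2]  # layers 0 and 1\n",
    "decoder = model[2:]  # layers 2 and 3\n",
    "\n",
    "x = torch.rand(4, input_size)\n",
    "with torch.no_grad():\n",
    "    same_output = torch.allclose(decoder(encoder(x)), model(x))\n",
    "shares_weights = encoder[0].weight is model[0].weight\n",
    "\n",
    "print(same_output, shares_weights)  # True True\n",
    "```\n",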
    "\n",
    "This architecture works well with a bottleneck layer of 32 units but may fail to converge with two units. Check the exercises in the *Bonus* section to better understand this failure and two options to address it: better weight initialization and changing the activation function.\n",
    "\n",
    "Here we opt for [PReLU units](https://arxiv.org/abs/1502.01852) in the bottleneck layer to add negative activations with a learnable parameter. This change affords additional wiggle room for the autoencoder to model data with only two units in the bottleneck layer.\n",
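    "\n",
    "PReLU computes $f(x) = \\max(0, x) + a \\min(0, x)$, where the slope $a$ for negative inputs is learned during training; `nn.PReLU` initializes it to `0.25` by default. A NumPy sketch of the forward pass with $a$ fixed at that initial value:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def relu(x):\n",
    "    # negative inputs are set to zero\n",
    "    return np.maximum(0, x)\n",
    "\n",
    "def prelu(x, a=0.25):\n",
    "    # identity for positive inputs, slope a for negative inputs\n",
    "    return np.maximum(0, x) + a * np.minimum(0, x)\n",
    "\n",
    "x = np.array([-2.0, -0.5, 0.0, 1.5])\n",
    "print(relu(x))\n",
    "print(prelu(x))\n",
    "```\n",
    "\n",
    "With ReLU, a bottleneck unit that only receives negative inputs outputs zero and passes no gradient back; PReLU keeps a nonzero slope $a$ for negative inputs.\n",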
    "\n",
    "![PReLU unit](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/prelu.png)\n",
    "\n",
    "**Instructions**\n",
    "* Please execute the cells below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:45.057831Z",
     "iopub.status.busy": "2021-05-25T01:00:45.057231Z",
     "iopub.status.idle": "2021-05-25T01:00:45.060108Z",
     "shell.execute_reply": "2021-05-25T01:00:45.060564Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 2\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size, input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "encoder = model[:2]\n",
    "decoder = model[2:]\n",
    "\n",
    "print(f'Autoencoder \\n\\n {model}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:45.066533Z",
     "iopub.status.busy": "2021-05-25T01:00:45.064307Z",
     "iopub.status.idle": "2021-05-25T01:00:45.068405Z",
     "shell.execute_reply": "2021-05-25T01:00:45.067912Z"
    }
   },
   "outputs": [],
   "source": [
    "print(f'Encoder \\n\\n {encoder}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:45.072567Z",
     "iopub.status.busy": "2021-05-25T01:00:45.072031Z",
     "iopub.status.idle": "2021-05-25T01:00:45.073881Z",
     "shell.execute_reply": "2021-05-25T01:00:45.074299Z"
    }
   },
   "outputs": [],
   "source": [
    "print(f'Decoder \\n\\n {decoder}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the autoencoder (2D)\n",
    "Train the network for `n_epochs=10` epochs and `batch_size=64` with `runSGD` and BCE loss, and visualize latent space.\n",
    "\n",
    "Please execute the cells below to train the autoencoder!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:00:45.078252Z",
     "iopub.status.busy": "2021-05-25T01:00:45.077683Z",
     "iopub.status.idle": "2021-05-25T01:01:03.864816Z",
     "shell.execute_reply": "2021-05-25T01:01:03.864347Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "# train the autoencoder\n",
    "runSGD(model, input_train, input_test, criterion='bce',\n",
    "       n_epochs=n_epochs, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:03.873087Z",
     "iopub.status.busy": "2021-05-25T01:01:03.872556Z",
     "iopub.status.idle": "2021-05-25T01:01:06.107450Z",
     "shell.execute_reply": "2021-05-25T01:01:06.107920Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_latent_generative(latent_test, y_test, decoder,\n",
    "                       image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Expressive power in 2D\n",
    "The latent space representation of the shallow autoencoder with a 2D bottleneck is similar to that of PCA. How can this linear dimensionality reduction technique be comparable to our non-linear autoencoder?\n",
    "\n",
    "Training an autoencoder with linear activation functions under MSE loss is [very similar to performing PCA](https://arxiv.org/abs/1804.10253). Using piece-wise linear units, sigmoidal output unit, and BCE loss doesn't seem to change this behavior qualitatively. The network lacks capacity in terms of learnable parameters to make good use of its non-linear operations and capture non-linear aspects of the data.\n",
    "\n",
    "The similarity between representations is apparent when plotting encoder maps side-by-side. Look for classes of digits that cluster successfully and those still mixing with others. \n",
    "\n",
    "Execute the cell below for a PCA vs. Autoencoder (2D) comparison!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:06.111850Z",
     "iopub.status.busy": "2021-05-25T01:01:06.111290Z",
     "iopub.status.idle": "2021-05-25T01:01:09.906107Z",
     "shell.execute_reply": "2021-05-25T01:01:09.906585Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_latent_ab(pca_latent_test, latent_test, y_test,\n",
    "               title_a='PCA', title_b='Autoencoder (2D)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Summary\n",
    "In this tutorial, we got comfortable with the basic techniques to create and visualize low-dimensional representations and build shallow autoencoders. \n",
    "\n",
    "**We saw that PCA and the shallow autoencoder have similar expressive power in 2D latent space, despite the autoencoder's non-linear character.**\n",
    "\n",
    "The shallow autoencoder lacks learnable parameters to take advantage of non-linear operations in encoding/decoding and capture non-linear patterns in data.\n",
    "\n",
    "The next tutorial extends the autoencoder architecture to learn richer internal representations of data required for tackling the MNIST cognitive task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:09.912251Z",
     "iopub.status.busy": "2021-05-25T01:01:09.911716Z",
     "iopub.status.idle": "2021-05-25T01:01:09.964417Z",
     "shell.execute_reply": "2021-05-25T01:01:09.964789Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Video 3: Wrap-up\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"V0gVrkyFd0Y\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Bonus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Failure mode with ReLU units in 2D\n",
    "An architecture with two units in the bottleneck layer, ReLU units, and default weight initialization may fail to converge, depending on the minibatch sequence, choice of the optimizer, etc. To illustrate this failure mode, we first set the random number generators (RNGs) to reproduce an example of failed convergence:\n",
    "```\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "```\n",
    "Afterward, we set the RNGs to reproduce an example of successful convergence:\n",
    "\n",
    "```\n",
    "torch.manual_seed(1)\n",
    "```\n",
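    "\n",
    "A minimal sketch of why resetting the seed reproduces a run, assuming only PyTorch: re-seeding before constructing a layer reproduces its initial weights exactly, while a different seed gives different weights. The helper `make_encoder_weights` is hypothetical and not part of this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def make_encoder_weights(seed):\n",
    "    # hypothetical helper: reset the PyTorch RNG, then build a 784 -> 2 linear layer\n",
    "    torch.manual_seed(seed)\n",
    "    return nn.Linear(784, 2).weight.detach().clone()\n",
    "\n",
    "w_a = make_encoder_weights(0)\n",
    "w_b = make_encoder_weights(0)\n",
    "w_c = make_encoder_weights(1)\n",
    "print(torch.equal(w_a, w_b), torch.equal(w_a, w_c))  # True False\n",
    "```\n",
    "\n",
    "The same logic explains why the training cells reset the RNGs twice: once before weight initialization and once before training, so that the minibatch sequence is also reproducible.\n",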
    "\n",
    "&nbsp;\n",
    "\n",
    "Train the network for `n_epochs=10` epochs with `batch_size=64`, and check the encoder map and reconstruction grid in each case.\n",
    "\n",
    "We then activate our x-ray vision and check the distribution of weights in the encoder and decoder components. Recall that the encoder maps input pixels to bottleneck units (encoder weights `shape=(2, 784)`), and the decoder maps bottleneck units to output pixels (decoder weights `shape=(784, 2)`).\n",
    "\n",
    "Network models are often initialized with random weights close to 0. The default weight initialization for linear layers in PyTorch samples from a uniform distribution `[-limit, limit]` with `limit=1/sqrt(fan_in)`, where `fan_in` is the number of input units in the weight tensor.\n",
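    "\n",
    "A short sketch to check this default, assuming a plain PyTorch `nn.Linear` pair with MNIST's 784 input pixels and 2 bottleneck units: it prints each weight tensor's shape, the limit `1/sqrt(fan_in)`, and the largest absolute initial weight.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "input_size, encoding_size = 784, 2  # MNIST pixels, bottleneck units\n",
    "\n",
    "encoder_layer = nn.Linear(input_size, encoding_size)  # weight shape (2, 784)\n",
    "decoder_layer = nn.Linear(encoding_size, input_size)  # weight shape (784, 2)\n",
    "\n",
    "for name, layer in [('encoder', encoder_layer), ('decoder', decoder_layer)]:\n",
    "    limit = 1 / math.sqrt(layer.in_features)  # fan_in = number of input units\n",
    "    max_abs = layer.weight.detach().abs().max().item()\n",
    "    print(name, tuple(layer.weight.shape), round(limit, 4), round(max_abs, 4))\n",
    "```\n",
    "\n",
    "Because `fan_in` is 784 for the encoder but only 2 for the decoder, the initial decoder weights span a much wider range (up to about 0.71) than the encoder weights (up to about 0.036).\n",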
    "\n",
    "We compare the distribution of weights at network initialization with their distribution after training. Weights that fail to learn during training retain their initial distribution, whereas weights adjusted by SGD during training are likely to change their distribution.\n",
    "\n",
    "Encoder weights may even acquire a bell-shaped distribution. This effect may be related to the following: SGD adds a sequence of positive and negative increments to each initial weight. The Central Limit Theorem (CLT) would predict a Gaussian histogram if the increments were independent both within and across sequences. The deviation from Gaussianity is thus a measure of the inter-dependency of SGD increments.\n",
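    "\n",
    "A toy NumPy illustration of the CLT argument, under loudly simplified assumptions: weights start from the default uniform initialization and then receive many *independent* random increments of fixed size (a stand-in for SGD updates, not actual SGD). The excess kurtosis is about -1.2 for a uniform distribution and 0 for a Gaussian, so its magnitude shrinks as the histogram becomes bell-shaped.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_weights, n_steps = 10000, 200\n",
    "limit = 1 / np.sqrt(784)  # default init limit for encoder weights\n",
    "\n",
    "w_init = rng.uniform(-limit, limit, size=n_weights)\n",
    "# assumption: independent +/- increments, a simplified stand-in for SGD updates\n",
    "increments = 0.1 * limit * rng.choice([-1.0, 1.0], size=(n_steps, n_weights))\n",
    "w_final = w_init + increments.sum(axis=0)\n",
    "\n",
    "def excess_kurtosis(x):\n",
    "    z = (x - x.mean()) / x.std()\n",
    "    return (z ** 4).mean() - 3.0\n",
    "\n",
    "print('initial excess kurtosis:', round(excess_kurtosis(w_init), 2))  # close to -1.2\n",
    "print('final excess kurtosis:', round(excess_kurtosis(w_final), 2))  # close to 0\n",
    "```\n",
    "\n",
    "With correlated increments, the final histogram need not approach a Gaussian, which is why a deviation from Gaussianity can hint at inter-dependent updates.\n",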
    "\n",
    "**Instructions:**\n",
    "* Please execute the cells below\n",
    "* Start with `torch_seed = 0` for an example of failed convergence\n",
    "* Check that the encoder mapping collapses onto a single axis\n",
    "* Verify that the collapsed dimension corresponds to weights that stayed unchanged during training\n",
    "* Set `torch_seed = 1` for an example of successful convergence\n",
    "* Run `help(get_layer_weights)` for additional details on retrieving learnable parameters (weights and biases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:09.971718Z",
     "iopub.status.busy": "2021-05-25T01:01:09.971156Z",
     "iopub.status.idle": "2021-05-25T01:01:28.584333Z",
     "shell.execute_reply": "2021-05-25T01:01:28.583822Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 2\n",
    "n_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "# set PyTorch RNG seed\n",
    "torch_seed = 0\n",
    "\n",
    "# reset RNG for weight initialization\n",
    "torch.manual_seed(torch_seed)\n",
    "np.random.seed(0)\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(encoding_size, input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "encoder = model[:2]\n",
    "decoder = model[2:]\n",
    "\n",
    "# retrieve weights and biases from the encoder and decoder before training\n",
    "encoder_w_init, encoder_b_init = get_layer_weights(encoder[0])\n",
    "decoder_w_init, decoder_b_init = get_layer_weights(decoder[0])\n",
    "\n",
    "# reset RNG for minibatch sequence\n",
    "torch.manual_seed(torch_seed)\n",
    "np.random.seed(0)\n",
    "\n",
    "# train the autoencoder\n",
    "runSGD(model, input_train, input_test, criterion='bce',\n",
    "       n_epochs=n_epochs, batch_size=batch_size)\n",
    "\n",
    "# retrieve weights and biases from the encoder and decoder after training\n",
    "encoder_w_train, encoder_b_train = get_layer_weights(encoder[0])\n",
    "decoder_w_train, decoder_b_train = get_layer_weights(decoder[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:28.589176Z",
     "iopub.status.busy": "2021-05-25T01:01:28.588393Z",
     "iopub.status.idle": "2021-05-25T01:01:31.988055Z",
     "shell.execute_reply": "2021-05-25T01:01:31.988494Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_latent_generative(latent_test, y_test, decoder, image_shape=image_shape)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:32.010916Z",
     "iopub.status.busy": "2021-05-25T01:01:32.007930Z",
     "iopub.status.idle": "2021-05-25T01:01:33.122275Z",
     "shell.execute_reply": "2021-05-25T01:01:33.122685Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_weights_ab(encoder_w_init, encoder_w_train, decoder_w_init,\n",
    "                decoder_w_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 3: Choosing weight initialization\n",
    "An improved weight initialization for ReLU units avoids the failure mode from the previous section. A popular choice for rectifier units is *Kaiming uniform*: sampling from the uniform distribution $\\mathcal{U}(-limit, limit)$ with $limit=\\sqrt{6/fan\\_in}$, where $fan\\_in$ is the number of input units in the weight tensor (see the relevant [article](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf) for details). Example of resetting all autoencoder weights to Kaiming uniform:\n",
    "```\n",
    "model.apply(init_weights_kaiming_uniform)\n",
    "```\n",
    "An alternative is *Kaiming normal*: sampling from a Gaussian distribution $\\mathcal{N}(\\mu, \\sigma^2)$ with $\\mu=0$ and $\\sigma=\\sqrt{2/fan\\_in}$. Example of resetting all but the last two autoencoder layers to Kaiming normal:\n",
    "```\n",
    "model[:-2].apply(init_weights_kaiming_normal)\n",
    "```\n",
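    "\n",
    "A minimal PyTorch sketch of both schemes, assuming the built-in `nn.init.kaiming_uniform_` and `nn.init.kaiming_normal_` with `nonlinearity='relu'`. The functions `my_init_kaiming_uniform` and `my_init_kaiming_normal` are hypothetical stand-ins, not the notebook's `init_weights_kaiming_uniform`/`init_weights_kaiming_normal` helpers, which may differ in details such as bias handling.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def my_init_kaiming_uniform(module):\n",
    "    # hypothetical stand-in: reset only Linear weights to Kaiming uniform (ReLU gain)\n",
    "    if isinstance(module, nn.Linear):\n",
    "        nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')\n",
    "\n",
    "def my_init_kaiming_normal(module):\n",
    "    # hypothetical stand-in: reset only Linear weights to Kaiming normal (ReLU gain)\n",
    "    if isinstance(module, nn.Linear):\n",
    "        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')\n",
    "\n",
    "torch.manual_seed(0)\n",
    "model = nn.Sequential(nn.Linear(784, 2), nn.ReLU(), nn.Linear(2, 784), nn.Sigmoid())\n",
    "fan_in = model[0].in_features  # 784 input pixels\n",
    "\n",
    "model[:2].apply(my_init_kaiming_uniform)  # reset the encoder only\n",
    "w_uniform = model[0].weight.detach().clone()\n",
    "print('uniform limit:', math.sqrt(6 / fan_in), 'max |w|:', w_uniform.abs().max().item())\n",
    "\n",
    "model[:2].apply(my_init_kaiming_normal)\n",
    "w_normal = model[0].weight.detach().clone()\n",
    "print('normal std:', math.sqrt(2 / fan_in), 'sample std:', w_normal.std().item())\n",
    "```\n",
    "\n",
    "`Module.apply` calls the function on every submodule recursively, which is why the `isinstance` check is needed to skip activation layers such as `nn.ReLU`. Both schemes give the weights the same variance, `2/fan_in`.\n",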
    "\n",
    "For more information on weight initialization, the references below are a good starting point:\n",
    "* [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)\n",
    "* [Understanding the difficulty of training deep feedforward neural networks](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)\n",
    "* [Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)\n",
    "\n",
    "**Instructions:**\n",
    "* Reset encoder weights with `init_weights_kaiming_uniform`\n",
    "* Compare with the result of resetting them with `init_weights_kaiming_normal`\n",
    "* See `help(init_weights_kaiming_uniform)` and `help(init_weights_kaiming_normal)` for additional details"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:33.130199Z",
     "iopub.status.busy": "2021-05-25T01:01:33.129658Z",
     "iopub.status.idle": "2021-05-25T01:01:52.163030Z",
     "shell.execute_reply": "2021-05-25T01:01:52.163470Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 2\n",
    "n_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "# set PyTorch RNG seed\n",
    "torch_seed = 0\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(encoding_size, input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "encoder = model[:2]\n",
    "decoder = model[2:]\n",
    "\n",
    "# reset RNGs for weight initialization\n",
    "torch.manual_seed(torch_seed)\n",
    "np.random.seed(0)\n",
    "\n",
    "######################################################################\n",
    "## TODO for students: reset encoder weights and biases\n",
    "######################################################################\n",
    "# reset encoder weights and biases\n",
    "# encoder.apply(...)\n",
    "\n",
    "# retrieve weights and biases from the encoder and decoder before training\n",
    "encoder_w_init, encoder_b_init = get_layer_weights(encoder[0])\n",
    "decoder_w_init, decoder_b_init = get_layer_weights(decoder[0])\n",
    "\n",
    "# reset RNGs for minibatch sequence\n",
    "torch.manual_seed(torch_seed)\n",
    "np.random.seed(0)\n",
    "\n",
    "# train the autoencoder\n",
    "runSGD(model, input_train, input_test, criterion='bce',\n",
    "       n_epochs=n_epochs, batch_size=batch_size)\n",
    "\n",
    "# retrieve weights and biases from the encoder and decoder after training\n",
    "encoder_w_train, encoder_b_train = get_layer_weights(encoder[0])\n",
    "decoder_w_train, decoder_b_train = get_layer_weights(decoder[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:01:52.171192Z",
     "iopub.status.busy": "2021-05-25T01:01:52.170587Z",
     "iopub.status.idle": "2021-05-25T01:02:10.918719Z",
     "shell.execute_reply": "2021-05-25T01:02:10.919182Z"
    }
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/Bonus_UnsupervisedLearning/solutions/Tutorial1_Solution_9d6c1017.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=558 height=413 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/Bonus_UnsupervisedLearning/static/Tutorial1_Solution_9d6c1017_11.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:02:10.927857Z",
     "iopub.status.busy": "2021-05-25T01:02:10.927304Z",
     "iopub.status.idle": "2021-05-25T01:02:14.333483Z",
     "shell.execute_reply": "2021-05-25T01:02:14.332994Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_latent_generative(latent_test, y_test, decoder,\n",
    "                       image_shape=image_shape)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:02:14.355245Z",
     "iopub.status.busy": "2021-05-25T01:02:14.338530Z",
     "iopub.status.idle": "2021-05-25T01:02:15.511353Z",
     "shell.execute_reply": "2021-05-25T01:02:15.510870Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_weights_ab(encoder_w_init, encoder_w_train, decoder_w_init,\n",
    "                decoder_w_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choosing the activation function\n",
    "An alternative to a specific weight initialization is to choose an activation unit that performs better in this context. We will use [PReLU](https://arxiv.org/abs/1502.01852) units in the bottleneck layer, which add a learnable slope for negative inputs.\n",
    "\n",
    "This change gives the autoencoder a little more wiggle room to model the data than ReLU units do.\n",
    "\n",
    "![PReLU unit](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/prelu.png)\n",
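    "\n",
    "A minimal sketch of a PReLU unit using PyTorch's built-in `nn.PReLU`, which by default learns a single slope `a` initialized to 0.25. It checks the output against the definition f(x) = max(0, x) + a * min(0, x) and contrasts it with ReLU.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "prelu = nn.PReLU()  # one learnable slope a, initialized to 0.25\n",
    "x = torch.tensor([-2.0, -0.5, 0.0, 1.5])\n",
    "\n",
    "with torch.no_grad():\n",
    "    y_prelu = prelu(x)\n",
    "    # same values from the definition f(x) = max(0, x) + a * min(0, x)\n",
    "    y_manual = torch.clamp(x, min=0) + prelu.weight * torch.clamp(x, max=0)\n",
    "    y_relu = torch.relu(x)\n",
    "\n",
    "print(y_prelu)  # values: -0.5, -0.125, 0.0, 1.5\n",
    "print(y_relu)   # negative inputs are zeroed by ReLU\n",
    "print(prelu.weight.requires_grad)  # True: the slope is trained with the other parameters\n",
    "```\n",
    "\n",
    "Because the slope for negative inputs is non-zero, a bottleneck unit that receives only negative inputs still passes gradients back to the encoder weights instead of going silent, as a ReLU unit would.\n",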
    "\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cells below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:02:15.518191Z",
     "iopub.status.busy": "2021-05-25T01:02:15.517643Z",
     "iopub.status.idle": "2021-05-25T01:02:34.262790Z",
     "shell.execute_reply": "2021-05-25T01:02:34.262303Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 2\n",
    "n_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "# set PyTorch RNG seed\n",
    "torch_seed = 0\n",
    "\n",
    "# reset RNGs for weight initialization\n",
    "torch.manual_seed(torch_seed)\n",
    "np.random.seed(0)\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, encoding_size),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size, input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "encoder = model[:2]\n",
    "decoder = model[2:]\n",
    "\n",
    "# retrieve weights and biases from the encoder and decoder before training\n",
    "encoder_w_init, encoder_b_init = get_layer_weights(encoder[0])\n",
    "decoder_w_init, decoder_b_init = get_layer_weights(decoder[0])\n",
    "\n",
    "# reset RNGs for minibatch sequence\n",
    "torch.manual_seed(torch_seed)\n",
    "np.random.seed(0)\n",
    "\n",
    "# train the autoencoder\n",
    "runSGD(model, input_train, input_test, criterion='bce',\n",
    "       n_epochs=n_epochs, batch_size=batch_size)\n",
    "\n",
    "# retrieve weights and biases from the encoder and decoder after training\n",
    "encoder_w_train, encoder_b_train = get_layer_weights(encoder[0])\n",
    "decoder_w_train, decoder_b_train = get_layer_weights(decoder[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:02:34.270903Z",
     "iopub.status.busy": "2021-05-25T01:02:34.270341Z",
     "iopub.status.idle": "2021-05-25T01:02:36.615862Z",
     "shell.execute_reply": "2021-05-25T01:02:36.616305Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_latent_generative(latent_test, y_test, decoder, image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:02:36.638450Z",
     "iopub.status.busy": "2021-05-25T01:02:36.633852Z",
     "iopub.status.idle": "2021-05-25T01:02:37.718507Z",
     "shell.execute_reply": "2021-05-25T01:02:37.718950Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_weights_ab(encoder_w_init, encoder_w_train, decoder_w_init,\n",
    "                decoder_w_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Qualitative analysis of NMF\n",
    "We proceed with *non-negative matrix factorization (NMF)* using `sklearn.decomposition.NMF` (docs [here](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.NMF.html)).\n",
    "\n",
    "A product of non-negative matrices $W$ and $H$ approximates the data matrix $X$, i.e., $X \\approx W H$.\n",
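    "\n",
    "To make the factorization concrete, here is a from-scratch NumPy sketch using the Lee and Seung multiplicative updates for the squared Frobenius error. This is a swapped-in algorithm chosen for brevity: scikit-learn's `NMF` uses a coordinate descent solver by default (`solver='cd'`), so its results will differ. The function name `nmf_multiplicative` and the toy data are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def nmf_multiplicative(X, n_components, n_iter=200, eps=1e-10, seed=0):\n",
    "    # Lee and Seung multiplicative updates minimizing ||X - W H||_F^2 with W, H >= 0\n",
    "    rng = np.random.default_rng(seed)\n",
    "    n_samples, n_features = X.shape\n",
    "    W = rng.uniform(0.0, 1.0, size=(n_samples, n_components))\n",
    "    H = rng.uniform(0.0, 1.0, size=(n_components, n_features))\n",
    "    for _ in range(n_iter):\n",
    "        H *= (W.T @ X) / (W.T @ W @ H + eps)\n",
    "        W *= (X @ H.T) / (W @ H @ H.T + eps)\n",
    "    return W, H\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "X = rng.uniform(0.0, 1.0, size=(50, 64))  # toy non-negative data, 50 samples x 64 features\n",
    "\n",
    "W0, H0 = nmf_multiplicative(X, 2, n_iter=0)  # initial factors only\n",
    "W, H = nmf_multiplicative(X, 2, n_iter=200)\n",
    "err_before = np.linalg.norm(X - W0 @ H0)\n",
    "err_after = np.linalg.norm(X - W @ H)\n",
    "print('error before:', round(err_before, 2), 'error after:', round(err_after, 2))\n",
    "```\n",
    "\n",
    "Each update multiplies entries by non-negative ratios, so $W$ and $H$ stay non-negative; a side effect is that an entry that reaches exactly zero stays at zero for the rest of the run.\n",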
    "\n",
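    "In scikit-learn's convention, where each row of $X$ is an image, the rows of $H$ (`nmf.components_`) play the same role as the principal components in PCA, while each row of $W$ holds an image's latent coordinates.\n",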
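    "\n",
    "A short sketch of how these factors map onto scikit-learn's API, assuming random non-negative toy data in place of MNIST: `fit_transform` returns $W$ (one row of latent coordinates per image), `components_` stores $H$ (one row per basis image), and `inverse_transform` maps latent coordinates back to pixel space as $W H$.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn import decomposition\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.uniform(0.0, 1.0, size=(100, 784))  # stand-in for 100 non-negative images\n",
    "\n",
    "nmf = decomposition.NMF(n_components=2, init='random', random_state=0, max_iter=500)\n",
    "W = nmf.fit_transform(X)  # latent coordinates, shape (100, 2)\n",
    "H = nmf.components_       # basis images, shape (2, 784)\n",
    "\n",
    "print(W.shape, H.shape, W.min() >= 0, H.min() >= 0)\n",
    "# inverse_transform maps latent coordinates back to pixel space as W @ H\n",
    "print(np.allclose(nmf.inverse_transform(W), W @ H))\n",
    "```\n",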
    "\n",
    "Digit classes `0` and `1` are the furthest apart in the latent space and are more clearly clustered than the other classes.\n",
    "\n",
    "Moving along the first component, images gradually come to resemble digit class `0`. The second component shows a similar progression toward a mix of digit classes `1` and `9`.\n",
    "\n",
    "The data is shifted by `0.5` to avoid failure modes near `0`; this is probably related to our scaling choice. Try it without the `0.5` shift.\n",
    "\n",
    "The parameter `init='random'`, used below, scales the initial non-negative random matrices and often provides better results; try the default initialization as well!\n",
    "\n",
    "Please execute the cells below to run NMF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:02:37.723315Z",
     "iopub.status.busy": "2021-05-25T01:02:37.722749Z",
     "iopub.status.idle": "2021-05-25T01:03:00.622006Z",
     "shell.execute_reply": "2021-05-25T01:03:00.622418Z"
    }
   },
   "outputs": [],
   "source": [
    "nmf = decomposition.NMF(n_components=2, init='random')\n",
    "\n",
    "nmf.fit(input_train + 0.5)\n",
    "\n",
    "nmf_latent_test = nmf.transform(input_test + 0.5)\n",
    "\n",
    "plot_latent_generative(nmf_latent_test, y_test, nmf.inverse_transform,\n",
    "                       image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:00.649022Z",
     "iopub.status.busy": "2021-05-25T01:03:00.639129Z",
     "iopub.status.idle": "2021-05-25T01:03:00.714985Z",
     "shell.execute_reply": "2021-05-25T01:03:00.715740Z"
    }
   },
   "outputs": [],
   "source": [
    "nmf_components = nmf.components_\n",
    "\n",
    "plot_row(nmf_components, image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:00.739094Z",
     "iopub.status.busy": "2021-05-25T01:03:00.738547Z",
     "iopub.status.idle": "2021-05-25T01:03:01.812275Z",
     "shell.execute_reply": "2021-05-25T01:03:01.811747Z"
    }
   },
   "outputs": [],
   "source": [
    "nmf_output_test = nmf.inverse_transform(nmf_latent_test)\n",
    "\n",
    "plot_row([input_test, nmf_output_test], image_shape=image_shape)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "Tutorial1",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
